Diamonds Dataset

The diamonds dataset that ships with the ggplot2 package contains 53,940 observations of diamond pieces. Each observation records ten attributes such as price (in USD), carat (weight), quality of the cut, and so on. I could not find what year it was collected.

# 1. load and quick look

# Load the diamonds dataset (shipped with ggplot2) and take a first look.
data("diamonds")
View(diamonds)    # open the data in RStudio's spreadsheet viewer

?diamonds         # help page: variable definitions and units

summary(diamonds) # five-number summaries for numerics, level counts for factors
##      carat               cut        color        clarity          depth      
##  Min.   :0.2000   Fair     : 1610   D: 6775   SI1    :13065   Min.   :43.00  
##  1st Qu.:0.4000   Good     : 4906   E: 9797   VS2    :12258   1st Qu.:61.00  
##  Median :0.7000   Very Good:12082   F: 9542   SI2    : 9194   Median :61.80  
##  Mean   :0.7979   Premium  :13791   G:11292   VS1    : 8171   Mean   :61.75  
##  3rd Qu.:1.0400   Ideal    :21551   H: 8304   VVS2   : 5066   3rd Qu.:62.50  
##  Max.   :5.0100                     I: 5422   VVS1   : 3655   Max.   :79.00  
##                                     J: 2808   (Other): 2531                  
##      table           price             x                y         
##  Min.   :43.00   Min.   :  326   Min.   : 0.000   Min.   : 0.000  
##  1st Qu.:56.00   1st Qu.:  950   1st Qu.: 4.710   1st Qu.: 4.720  
##  Median :57.00   Median : 2401   Median : 5.700   Median : 5.710  
##  Mean   :57.46   Mean   : 3933   Mean   : 5.731   Mean   : 5.735  
##  3rd Qu.:59.00   3rd Qu.: 5324   3rd Qu.: 6.540   3rd Qu.: 6.540  
##  Max.   :95.00   Max.   :18823   Max.   :10.740   Max.   :58.900  
##                                                                   
##        z         
##  Min.   : 0.000  
##  1st Qu.: 2.910  
##  Median : 3.530  
##  Mean   : 3.539  
##  3rd Qu.: 4.040  
##  Max.   :31.800  
## 
# Missing-data check: count NAs across the entire data frame at once.
sum(is.na(diamonds)) # no missing values
## [1] 0
# Quick distribution checks on the key numeric columns.
hist(diamonds$carat) # right skewed with a few outliers

hist(diamonds$price) # right skewed

hist(diamonds$depth) # mostly normal

# It makes sense for quality-related measures (price, weight) to be
# right-skewed: the better the diamond, the rarer it is (and vice versa).
# The absence of NAs suggests the data was carefully collected or pre-cleaned.

Explore Relationships

# 2. Any Ideas?

# Scatter plots of price against the two proportion measurements.
plot(diamonds$depth, diamonds$price) # no clear relationship

plot(diamonds$table, diamonds$price) # same

# Per-cut summary: mean/min/max of price and carat for each cut level.
# NOTE(review): select() after group_by() keeps the grouping column (and
# prints "Adding missing grouping variables"); summarise_all() is superseded
# by across() in current dplyr, but behaves the same here.
dia_bycut <- diamonds %>% 
  group_by(cut) %>% 
  select(price, carat) %>% 
  summarise_all(list(mean = mean, min = min, max =max))

kable(dia_bycut) # render the summary as a markdown table
cut price_mean carat_mean price_min carat_min price_max carat_max
Fair 4358.758 1.0461366 337 0.22 18574 5.01
Good 3928.864 0.8491847 327 0.23 18788 3.01
Very Good 3981.760 0.8063814 336 0.20 18818 4.00
Premium 4584.258 0.8919549 326 0.20 18823 4.01
Ideal 3457.542 0.7028370 326 0.20 18806 3.50
# Depth distribution, one histogram panel per cut category.
histdepth_bycut <- ggplot(diamonds, aes(x = depth)) +
  geom_histogram(binwidth = 0.5) +
  facet_wrap(~ cut) +
  labs(x = "Depth", y = "Count")

histdepth_bycut

# Price histograms stacked by cut (fill colour encodes the cut level).
histprice_bycut <- ggplot(diamonds, aes(x = price, fill = cut)) +
  geom_histogram(binwidth = 100) +
  labs(x = "Price", y = "Count")

histprice_bycut

# Price vs carat, coloured by cut quality.
# FIX: the legend here is driven by the colour aesthetic, not fill, so the
# original guides(fill = ...) call had no effect and the legend title stayed
# "cut". Retitle the colour legend instead.
scatter_price <- diamonds %>%
  ggplot(aes(x = carat, y = price, color = cut)) +
  geom_point(alpha = 0.7) +
  xlab("Carat") + ylab("Price") +
  guides(color = guide_legend(title = "Cut"))

scatter_price

# Baseline linear model: price regressed on every other column.
# The ordered factors (cut, color, clarity) enter as polynomial contrasts,
# hence the .L / .Q / .C / ^4 coefficient names in the summary below.
linear_price <- lm(price ~ ., data = diamonds)
summary(linear_price)
## 
## Call:
## lm(formula = price ~ ., data = diamonds)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -21376.0   -592.4   -183.5    376.4  10694.2 
## 
## Coefficients:
##              Estimate Std. Error  t value Pr(>|t|)    
## (Intercept)  5753.762    396.630   14.507  < 2e-16 ***
## carat       11256.978     48.628  231.494  < 2e-16 ***
## cut.L         584.457     22.478   26.001  < 2e-16 ***
## cut.Q        -301.908     17.994  -16.778  < 2e-16 ***
## cut.C         148.035     15.483    9.561  < 2e-16 ***
## cut^4         -20.794     12.377   -1.680  0.09294 .  
## color.L     -1952.160     17.342 -112.570  < 2e-16 ***
## color.Q      -672.054     15.777  -42.597  < 2e-16 ***
## color.C      -165.283     14.725  -11.225  < 2e-16 ***
## color^4        38.195     13.527    2.824  0.00475 ** 
## color^5       -95.793     12.776   -7.498 6.59e-14 ***
## color^6       -48.466     11.614   -4.173 3.01e-05 ***
## clarity.L    4097.431     30.259  135.414  < 2e-16 ***
## clarity.Q   -1925.004     28.227  -68.197  < 2e-16 ***
## clarity.C     982.205     24.152   40.668  < 2e-16 ***
## clarity^4    -364.918     19.285  -18.922  < 2e-16 ***
## clarity^5     233.563     15.752   14.828  < 2e-16 ***
## clarity^6       6.883     13.715    0.502  0.61575    
## clarity^7      90.640     12.103    7.489 7.06e-14 ***
## depth         -63.806      4.535  -14.071  < 2e-16 ***
## table         -26.474      2.912   -9.092  < 2e-16 ***
## x           -1008.261     32.898  -30.648  < 2e-16 ***
## y               9.609     19.333    0.497  0.61918    
## z             -50.119     33.486   -1.497  0.13448    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 1130 on 53916 degrees of freedom
## Multiple R-squared:  0.9198, Adjusted R-squared:  0.9198 
## F-statistic: 2.688e+04 on 23 and 53916 DF,  p-value: < 2.2e-16
# Train/test split: fix the seed for reproducibility, then assign each of
# the 53,940 rows to train (TRUE) or test (FALSE) with equal probability.
set.seed (11)
diamonds$train <- sample(c(TRUE , FALSE), nrow(diamonds), replace = TRUE)
# draws one TRUE/FALSE per row, i.e. a ~50/50 random split

dia_train <- filter(diamonds, train)
dia_test <- filter(diamonds, !train)

# Four nested candidate models of increasing size.
# NOTE(review): the train indicator column is still in the data, so the
# "~ ." formulas pick it up; it is constant within dia_train so lm() drops
# it silently, but it would be cleaner to remove the column before modelling.
lm1 <- lm(price ~ carat + cut, data = dia_train)
lm2 <- lm(price ~ carat + cut + depth + table, data = dia_train)
lm3 <- lm(price ~ . - x - y -z, data = dia_train)
lm4 <- lm(price ~ . , data = dia_train)

# Out-of-sample mean squared error for lm1, computed by hand.

lm1hat <-  predict(lm1, newdata = dia_test)
mean( (dia_test$price - lm1hat)^2 )
## [1] 2295292
# we can automate the process with a function

# Out-of-sample MSE: predict on the held-out dia_test set and average the
# squared prediction errors against the observed prices.
mse <- function(model) {
  predictions <- predict(model, newdata = dia_test)
  mean((dia_test$price - predictions)^2)
}

# Apply the MSE helper to each candidate model in turn.
mse(lm1)
## [1] 2295292
mse(lm2)
## [1] 2285418
mse(lm3)
## [1] 1337119
mse(lm4)
## [1] 1402004
# apply our function to multiple models at once
lapply(list(lm1, lm2, lm3, lm4), mse)
## [[1]]
## [1] 2295292
## 
## [[2]]
## [1] 2285418
## 
## [[3]]
## [1] 1337119
## 
## [[4]]
## [1] 1402004
# save the results into a one-row data frame (one column per model)
mse_table <- as.data.frame(lapply(list(lm1, lm2, lm3, lm4), mse),
                           col.names = c("lm1", "lm2", "lm3", "lm4"))
library(glmnet)
## Loading required package: Matrix
## 
## Attaching package: 'Matrix'
## The following objects are masked from 'package:tidyr':
## 
##     expand, pack, unpack
## Loaded glmnet 4.1-3
# Design matrices for glmnet (it wants a numeric matrix, not a formula).
# NOTE(review): x1 is built but never used below; x2 is the one used throughout.
x1 <- model.matrix(~ ., select(dia_train, -price))
x2 <- model.matrix(price ~ ., dia_train)
y <- dia_train$price

lambda_grid <- 10^seq(3, -3, by = -.1) 
# it's important to understand what this grid does: lambda runs from
# 10^3 = 1000 down to 10^-3 = 0.001. the reason we sequence inside the power
# is that it makes the grid log-spaced: it varies by small increments near
# 0.001 and by large increments near 1000, so the search is most sensitive
# exactly where we want it to be.

# Ridge regression (alpha = 0) fitted over the whole lambda grid.
ridge <- glmnet(x2, y, alpha = 0, lambda = lambda_grid)
summary(ridge)
##           Length Class     Mode   
## a0          61   -none-    numeric
## beta      1525   dgCMatrix S4     
## df          61   -none-    numeric
## dim          2   -none-    numeric
## lambda      61   -none-    numeric
## dev.ratio   61   -none-    numeric
## nulldev      1   -none-    numeric
## npasses      1   -none-    numeric
## jerr         1   -none-    numeric
## offset       1   -none-    logical
## call         5   -none-    call   
## nobs         1   -none-    numeric
# the grid runs from large lambda to small, so the first fitted models are
# the most heavily penalized and their coefficients should be smaller:
coef(ridge)[, 1] # most shrunken (lambda = 1000)
##   (Intercept)   (Intercept)         carat         cut.L         cut.Q 
## -9358.7185818     0.0000000  3121.6705399   428.6346699  -153.3709702 
##         cut.C         cut^4       color.L       color.Q       color.C 
##    -0.1020196     0.6200210  -921.9242316  -284.9004911   -69.7256961 
##       color^4       color^5       color^6     clarity.L     clarity.Q 
##    15.5184250   -47.4734964   -57.1629559  2249.9796100  -729.2600017 
##     clarity.C     clarity^4     clarity^5     clarity^6     clarity^7 
##   -19.8564166    52.8007675   -29.9372590    35.9108281   114.8650951 
##         depth         table             x             y             z 
##    -3.4973177    -6.3627280   706.7594218   725.0183722   841.3276792 
##     trainTRUE 
##     0.0000000
coef(ridge)[, 40] # less shrunken (smaller lambda, closer to OLS)
##  (Intercept)  (Intercept)        carat        cut.L        cut.Q        cut.C 
##  4573.414644     0.000000 10869.477144   599.517551  -292.118569   103.388704 
##        cut^4      color.L      color.Q      color.C      color^4      color^5 
##   -31.579700 -1933.089636  -665.030232  -171.691705     9.240034   -85.324238 
##      color^6    clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5 
##   -55.730843  4112.651184 -1927.661068   992.445995  -387.185503   225.347400 
##    clarity^6    clarity^7        depth        table            x            y 
##     6.110287    96.873016   -55.511042   -25.238539 -1419.531759   587.309083 
##            z    trainTRUE 
##   -65.532631     0.000000
summary(lm1) # for comparison: the unpenalized OLS fit of the smallest model
## 
## Call:
## lm(formula = price ~ carat + cut, data = dia_train)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -17389.7   -788.4    -38.9    511.1  12711.9 
## 
## Coefficients:
##             Estimate Std. Error  t value Pr(>|t|)    
## (Intercept) -2678.31      21.72 -123.297  < 2e-16 ***
## carat        7839.95      19.84  395.196  < 2e-16 ***
## cut.L        1246.49      36.44   34.207  < 2e-16 ***
## cut.Q        -557.04      32.34  -17.225  < 2e-16 ***
## cut.C         365.52      28.41   12.864  < 2e-16 ***
## cut^4          79.67      22.92    3.476  0.00051 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 1508 on 26999 degrees of freedom
## Multiple R-squared:  0.8543, Adjusted R-squared:  0.8543 
## F-statistic: 3.167e+04 on 5 and 26999 DF,  p-value: < 2.2e-16
# glmnet can cross-validate on its own and pick the best lambda
cv_ridge <- cv.glmnet(x2, y, alpha = 0, lambda = lambda_grid)
plot(cv_ridge)

ridge_star <- cv_ridge$lambda.min
# CV suggests 0.001 -- the smallest lambda in the grid -- is best
ridge_hat_star <- predict(ridge , s = ridge_star, model.matrix(price ~ ., dia_test))
mse_table$ridge_0.001  <- mean (( ridge_hat_star - dia_test$price)^2)

# however, this lambda was picked using only the training data, so it need
# not perform best on the test set. also, the winner sitting at the very
# edge of the grid (the smallest value) is worth digging into: the true
# optimum may lie outside the grid.

# we can change the value of lambda by hand within the predict function.
# s represents lambda; when s is a single value, predict returns a vector
ridge_hat1  <- predict(ridge  , s = 1, model.matrix(price ~ ., dia_test))
ridge_hat100 <- predict(ridge , s = 100, model.matrix(price ~ ., dia_test))
mean( (ridge_hat1  - dia_test$price)^2 ) 
## [1] 1298093
# see, lambda = 1 performed better than the lambda = 0.001 that CV suggested
mean( (ridge_hat100  - dia_test$price)^2 )
## [1] 1446700
mse_table$ridge_1    <-  mean( (ridge_hat1  - dia_test$price)^2 )
mse_table$ridge_100  <-  mean( (ridge_hat100  - dia_test$price)^2 )

# when s is NULL, predict returns a matrix (one column per lambda in the grid)
ridge_hat_all <- predict(ridge  , s = NULL, model.matrix(price ~ ., dia_test))
head(ridge_hat_all)
##           s0         s1         s2         s3         s4         s5         s6
## 1 -1963.1550 -2053.6863 -2117.1977 -2154.8000 -2169.1583 -2162.2544 -2137.8246
## 2 -1722.8065 -1778.2402 -1808.3148 -1814.0126 -1797.9644 -1761.7860 -1709.1989
## 3  -901.3231  -941.6041  -961.4892  -961.0721  -942.2261  -905.7931  -854.8382
## 4 -2221.4396 -2445.5926 -2639.6673 -2804.6815 -2943.1659 -3057.4164 -3150.8351
## 5 -1528.8689 -1618.1994 -1682.7371 -1723.7059 -1743.6107 -1744.2762 -1729.0198
## 6 -1577.8931 -1650.5645 -1700.0384 -1727.4441 -1735.0679 -1724.9148 -1700.2573
##           s7         s8         s9        s10       s11        s12        s13
## 1 -2098.8453 -2048.6617 -1990.9427 -1929.3347 -1866.447 -1805.4950 -1748.7768
## 2 -1642.9983 -1566.5231 -1483.6282 -1398.3432 -1313.278 -1232.1784 -1157.6720
## 3  -791.5077  -718.5689  -639.4267  -557.6252  -475.945  -397.7218  -325.4761
## 4 -3226.1715 -3286.1836 -3333.5526 -3370.6839 -3399.506 -3421.9207 -3439.4395
## 5 -1700.4673 -1661.5659 -1615.5464 -1565.6560 -1514.131 -1463.8726 -1416.9440
## 6 -1663.9733 -1619.2059 -1569.2072 -1517.3185 -1465.081 -1415.4060 -1370.0780
##          s14        s15        s16        s17         s18         s19
## 1 -1695.6042 -1647.7884 -1606.0135 -1570.2100 -1539.65914 -1513.00478
## 2 -1088.6012 -1027.0366  -973.7140  -928.4324  -890.19119  -857.25607
## 3  -258.3487  -198.3281  -146.1429  -101.6472   -63.92877   -31.34733
## 4 -3452.5984 -3462.6550 -3470.4463 -3476.5303 -3481.22729 -3484.63602
## 5 -1372.7454 -1332.9121 -1298.0934 -1268.2678 -1242.83972 -1220.66180
## 6 -1328.1434 -1290.9394 -1258.9233 -1231.9239 -1209.26562 -1189.84356
##            s20         s21         s22         s23         s24         s25
## 1 -1490.149949 -1470.78066 -1454.72344 -1441.62076 -1430.66693 -1421.55202
## 2  -829.433525  -806.29879  -787.58171  -772.75831  -760.84322  -751.40559
## 3    -3.712143    19.39832    38.25075    53.34083    65.63038    75.53441
## 4 -3487.136683 -3488.99096 -3490.43910 -3491.62010 -3492.52681 -3493.23890
## 5 -1201.679111 -1185.64781 -1172.43655 -1161.74200 -1152.88449 -1145.60132
## 6 -1173.543014 -1160.11498 -1149.39967 -1141.06484 -1134.51643 -1129.49011
##           s26         s27         s28         s29        s30        s31
## 1 -1414.10196 -1407.96423 -1402.91255 -1398.70990 -1395.3327 -1392.5313
## 2  -744.09379  -738.39800  -733.96793  -730.49618  -727.8437  -725.7528
## 3    83.36116    89.58679    94.53521    98.50498   101.6030   104.0967
## 4 -3493.82877 -3494.30535 -3494.68968 -3494.99052 -3495.2445 -3495.4408
## 5 -1139.72635 -1134.94808 -1131.06380 -1127.87137 -1125.3334 -1123.2476
## 6 -1125.73998 -1122.93844 -1120.85798 -1119.31291 -1118.1924 -1117.3568
##          s32        s33        s34        s35        s36        s37        s38
## 1 -1390.2892 -1388.4396 -1386.9371 -1385.8067 -1384.7873 -1384.0324 -1383.4376
## 2  -724.1504  -722.8829  -721.8939  -721.1663  -720.5297  -720.0717  -719.7154
## 3   106.0452   107.6144   108.8615   109.7909   110.6118   111.2127   111.6831
## 4 -3495.6050 -3495.7314 -3495.8322 -3495.9260 -3495.9825 -3496.0386 -3496.0846
## 5 -1121.5925 -1120.2367 -1119.1429 -1118.3250 -1117.5882 -1117.0468 -1116.6213
## 6 -1116.7510 -1116.2977 -1115.9652 -1115.7319 -1115.5346 -1115.4032 -1115.3039
##          s39        s40        s41        s42        s43        s44        s45
## 1 -1382.9239 -1382.5639 -1382.2761 -1381.9881 -1381.8472 -1381.6333 -1381.4925
## 2  -719.4124  -719.2031  -719.0357  -718.8691  -718.7903  -718.6664  -718.5877
## 3   112.0850   112.3657   112.5901   112.8125   112.9216   113.0859   113.1929
## 4 -3496.1150 -3496.1470 -3496.1721 -3496.1833 -3496.2021 -3496.2073 -3496.2120
## 5 -1116.2537 -1115.9982 -1115.7938 -1115.5875 -1115.4900 -1115.3355 -1115.2356
## 6 -1115.2204 -1115.1666 -1115.1230 -1115.0779 -1115.0585 -1115.0259 -1115.0056
##          s46        s47        s48        s49        s50        s51        s52
## 1 -1381.3532 -1381.2156 -1381.0807 -1380.9491 -1380.8211 -1380.6971 -1380.5772
## 2  -718.5101  -718.4349  -718.3631  -718.2954  -718.2321  -718.1736  -718.1197
## 3   113.2978   113.4000   113.4986   113.5929   113.6826   113.7676   113.8480
## 4 -3496.2133 -3496.2104 -3496.2044 -3496.1962 -3496.1864 -3496.1756 -3496.1642
## 5 -1115.1359 -1115.0372 -1114.9405 -1114.8463 -1114.7552 -1114.6673 -1114.5828
## 6 -1114.9860 -1114.9674 -1114.9504 -1114.9357 -1114.9236 -1114.9142 -1114.9077
##          s53        s54        s55        s56        s57        s58        s59
## 1 -1380.4613 -1380.3493 -1380.2411 -1380.1365 -1380.0353 -1379.9373 -1379.8422
## 2  -718.0704  -718.0255  -717.9848  -717.9481  -717.9150  -717.8854  -717.8589
## 3   113.9238   113.9952   114.0624   114.1257   114.1854   114.2416   114.2948
## 4 -3496.1525 -3496.1408 -3496.1291 -3496.1177 -3496.1065 -3496.0957 -3496.0852
## 5 -1114.5016 -1114.4236 -1114.3488 -1114.2771 -1114.2081 -1114.1418 -1114.0779
## 6 -1114.9039 -1114.9028 -1114.9043 -1114.9082 -1114.9143 -1114.9225 -1114.9326
##          s60
## 1 -1379.7499
## 2  -717.8353
## 3   114.3450
## 4 -3496.0751
## 5 -1114.0163
## 6 -1114.9444
# let's automate out of sample mse calculation and finding best lambda
# Find the lambda with the best out-of-sample MSE for a glmnet fit.
#
# model:      the fitted glmnet object. Kept for backward compatibility only:
#             the predictions already live in hat_matrix, so re-running
#             predict() here -- as the first draft did, then discarded the
#             result -- is wasted work and has been removed.
# hat_matrix: matrix of test-set predictions, one column per lambda
#             (as returned by predict(model, s = NULL, ...))
# actual:     vector of observed outcomes; defaults to the test-set prices
#
# Returns c(index of the best lambda, its out-of-sample MSE).
bestmse_shrink <- function(model, hat_matrix, actual = dia_test$price) {
  # column-wise MSE in one vectorized step (no growing a vector in a loop);
  # actual recycles down each column of hat_matrix
  mses <- unname(colMeans((hat_matrix - actual)^2))
  c(which.min(mses), min(mses))
}

# first element is the index of the best lambda, second is the best
# out-of-sample MSE it gives
best_ridge <- bestmse_shrink(ridge, ridge_hat_all) 
# so it's the 24th lambda in the grid that gives the best out-of-sample MSE:
lambda_grid[24]
## [1] 5.011872
mse_table$ridge_5 <-  best_ridge[2]
#
# Lasso: same grid, but alpha = 1 switches the penalty from L2 to L1,
# which can zero coefficients out entirely (variable selection).
lasso <- glmnet(x2, y, alpha = 1, lambda = lambda_grid)
summary(lasso)
##           Length Class     Mode   
## a0          61   -none-    numeric
## beta      1525   dgCMatrix S4     
## df          61   -none-    numeric
## dim          2   -none-    numeric
## lambda      61   -none-    numeric
## dev.ratio   61   -none-    numeric
## nulldev      1   -none-    numeric
## npasses      1   -none-    numeric
## jerr         1   -none-    numeric
## offset       1   -none-    logical
## call         5   -none-    call   
## nobs         1   -none-    numeric
coef(lasso)[, 1]  # at the largest lambda it eliminated all variables except carat
## (Intercept) (Intercept)       carat       cut.L       cut.Q       cut.C 
##   -548.5598      0.0000   5601.3582      0.0000      0.0000      0.0000 
##       cut^4     color.L     color.Q     color.C     color^4     color^5 
##      0.0000      0.0000      0.0000      0.0000      0.0000      0.0000 
##     color^6   clarity.L   clarity.Q   clarity.C   clarity^4   clarity^5 
##      0.0000      0.0000      0.0000      0.0000      0.0000      0.0000 
##   clarity^6   clarity^7       depth       table           x           y 
##      0.0000      0.0000      0.0000      0.0000      0.0000      0.0000 
##           z   trainTRUE 
##      0.0000      0.0000
coef(lasso)[, 10] # kept carat plus some color, clarity and one cut contrast
## (Intercept) (Intercept)       carat       cut.L       cut.Q       cut.C 
## -2677.82669     0.00000  8194.94827   148.06335     0.00000     0.00000 
##       cut^4     color.L     color.Q     color.C     color^4     color^5 
##     0.00000 -1032.62945   -32.59695     0.00000     0.00000     0.00000 
##     color^6   clarity.L   clarity.Q   clarity.C   clarity^4   clarity^5 
##     0.00000  2601.03959  -685.06392     0.00000     0.00000     0.00000 
##   clarity^6   clarity^7       depth       table           x           y 
##     0.00000     0.00000     0.00000     0.00000     0.00000     0.00000 
##           z   trainTRUE 
##     0.00000     0.00000
coef(lasso)[, 40] # eliminated none, but the coefficients are still shrunken
##  (Intercept)  (Intercept)        carat        cut.L        cut.Q        cut.C 
##  4601.497433     0.000000 10869.007600   600.725554  -294.378963   107.382448 
##        cut^4      color.L      color.Q      color.C      color^4      color^5 
##   -28.527528 -1932.412433  -664.244657  -170.883619     9.281562   -85.184584 
##      color^6    clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5 
##   -55.618185  4113.154963 -1926.805899   991.688260  -386.565878   224.745655 
##    clarity^6    clarity^7        depth        table            x            y 
##     5.853690    96.305213   -55.885453   -25.298601 -1318.067353   484.915506 
##            z    trainTRUE 
##   -64.325110     0.000000
# use cross-validation to pick the best lasso lambda
cv_lasso <- cv.glmnet(x2, y, alpha = 1, lambda = lambda_grid)
plot(cv_lasso)

# FIX: the best lambda must come from the lasso CV object, not the ridge
# one -- the original read cv_ridge$lambda.min (a copy/paste slip).
lasso_star <- cv_lasso$lambda.min
lasso_hat_star <- predict(lasso, s = lasso_star, model.matrix(price ~ ., dia_test))
mse_table$lasso_0.001  <- mean (( lasso_hat_star - dia_test$price)^2)

# check other lambdas for out-of-sample performance
lasso_hat1    <- predict(lasso , s = 1, model.matrix(price ~ ., dia_test))
lasso_hat100  <- predict(lasso , s = 100, model.matrix(price ~ ., dia_test))

mse_table$lasso_1   <-  mean (( lasso_hat1  - dia_test$price)^2)
mse_table$lasso_100 <-  mean (( lasso_hat100  - dia_test$price)^2)

# let's check them all with the helper from the ridge section
lasso_hat_all <- predict(lasso  , s = NULL, model.matrix(price ~ ., dia_test))
best_lasso <- bestmse_shrink(lasso, lasso_hat_all) 
best_lasso
## [1]      31 1272705
lambda_grid[31]
## [1] 1
# the 31st lambda, which is 1, and we already saved its MSE into mse_table
# mse_table$lasso_1 <-  best_lasso[2]

# collect every model's MSE into one long table for side-by-side comparison
mse_all <- as.data.frame(t(mse_table))
min(mse_all$V1)
## [1] 1272705
kable(mse_all)
V1
lm1 2295292
lm2 2285418
lm3 1337119
lm4 1402004
ridge_0.001 1323918
ridge_1 1298093
ridge_100 1446700
ridge_5 1281099
lasso_0.001 1316549
lasso_1 1272705
lasso_100 1504542
coef(lasso)[, 31] # inspect our champion's coefficients (lasso, lambda = 1)
##  (Intercept)  (Intercept)        carat        cut.L        cut.Q        cut.C 
##  4578.747552     0.000000 10802.133061   604.230651  -301.421697   122.545747 
##        cut^4      color.L      color.Q      color.C      color^4      color^5 
##   -12.704798 -1924.422334  -657.383713  -165.748616     8.984044   -83.355879 
##      color^6    clarity.L    clarity.Q    clarity.C    clarity^4    clarity^5 
##   -54.466311  4108.124757 -1913.183733   980.190071  -379.827149   218.440085 
##    clarity^6    clarity^7        depth        table            x            y 
##     3.934344    92.900303   -56.953337   -25.249497  -834.788004    16.013504 
##            z    trainTRUE 
##   -47.211912     0.000000
# check scale of the variables maybe can normalize to make
# the shrinkage more reliable

# Z-score a numeric vector: subtract the mean, divide by the standard
# deviation.
#
# x:     numeric vector to standardize
# na.rm: if TRUE, compute mean and sd over non-missing values only (NAs
#        stay NA in the output). Default FALSE preserves the original
#        behaviour, where any NA makes the whole result NA.
#
# Returns a numeric vector the same length as x.
normalize <- function(x, na.rm = FALSE) {
  (x - mean(x, na.rm = na.rm)) / sd(x, na.rm = na.rm)
}

# normalize(diamonds$carat)
# names(diamonds)

# Add normalized versions of the numeric columns alongside the originals.
dia_norm <- diamonds %>% 
    mutate(
       norm_carat = normalize(carat),
       norm_depth = normalize(depth),
       norm_table = normalize(table),
       price_norm = normalize(price),
       norm_x = normalize(x),
       norm_y = normalize(y),
       norm_z = normalize(z)
       )

# if we rebuilt the shrinkage models on the normalized predictors, I'd
# expect depth and table to shrink more and carat less -- try it and
# check for yourself.

# Explore transformations of the carat-price relationship:
plot(diamonds$carat, diamonds$price)

plot((diamonds$carat)^2, diamonds$price)

plot(log(diamonds$carat), diamonds$price)

# the log-log version looks the most linear:
plot(log(diamonds$carat), log(diamonds$price))

# Log-log scatter of price vs carat, coloured by cut.
# FIX: the legend is driven by the colour aesthetic, so guides() must target
# color -- the original guides(fill = ...) call did nothing and the legend
# title stayed "cut".
scatter_lnprice <- diamonds %>%
  ggplot(aes(x = log(carat), y = log(price), color = cut)) +
  geom_point(alpha = 0.7) +
  xlab("Natural Log of Carat") + ylab("Natural Log of Price") +
  guides(color = guide_legend(title = "Cut"))

scatter_lnprice

# now let's try some models with the logged variables

# FIX: the original kept the raw price column in diamonds_log, so the
# "ln_price ~ ." formulas regressed log-price on price itself -- target
# leakage that artificially inflates the fit. Drop price along with carat.
# Also drop the train indicator after splitting: it is constant within each
# split, which is what caused the "rank-deficient fit" warnings from
# predict() further down.
diamonds_log <- diamonds %>% 
  mutate(ln_carat = log(carat), ln_price = log(price)) %>% 
  select(-carat, -price)

dialog_train <- filter(diamonds_log, train) %>% select(-train)
dialog_test <- filter(diamonds_log, !train) %>% select(-train)

lm3_log <- lm(ln_price ~ . - x - y - z, data = dialog_train)
lm4_log <- lm(ln_price ~ . , data = dialog_train)

# Out-of-sample mean squared error.
# Remember: the predictions here are log prices, so exponentiate before
# comparing against dollar prices.
lm3_loghat <-  exp(predict(lm3_log, newdata = dialog_test))
## Warning in predict.lm(lm3_log, newdata = dialog_test): prediction from a rank-
## deficient fit may be misleading
mean( (dia_test$price - lm3_loghat)^2 )
## [1] 768614.7
# dia_test holds the same test-set prices as dialog_test, just not logged
lm4_loghat <-  exp(predict(lm4_log, newdata = dialog_test))
## Warning in predict.lm(lm4_log, newdata = dialog_test): prediction from a rank-
## deficient fit may be misleading
mean( (dia_test$price - lm4_loghat)^2 ) 
## [1] 658183.5
# these MSEs are much better than all the other models we had.
# NOTE(review): exp(predict()) of a log-price model is biased low on the
# dollar scale (Jensen's inequality); a smearing or +sigma^2/2 correction
# would be more careful. Also verify the predictors exclude raw price,
# otherwise these numbers reflect target leakage.
# you can try running a log-log regression for ridge and lasso now

I hope this was helpful, y'all. Feel free to send me R or assignment questions whenever. See you at the next review session.

Peace,

Lutfi